/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.ensemble;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryStruct;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

/**
 * Hive UDAF that returns the row (all given arguments) whose first argument is the largest one in
 * the group.
 *
 * <p>
 * Usage: {@code maxrow(compare_key, col1, col2, ...)}. The result is a struct
 * {@code <col0, col1, ...>} holding copies of the arguments of the row with the largest
 * {@code compare_key}. Comparison uses {@link ObjectInspectorUtils#compare}. A row replaces the
 * current one only when it is strictly larger, so ties keep whichever row this evaluator saw first.
 * That order is not guaranteed across tasks. Returns {@code NULL} when nothing was aggregated.
 */
@Description(name = "maxrow",
        value = "_FUNC_(ANY compare, ...) - Returns a row that has maximum value in the 1st argument")
public final class MaxRowUDAF extends AbstractGenericUDAFResolver {

    /**
     * Validates the arguments and returns the evaluator.
     *
     * @param parameters type information of the call's arguments; the first is the compare key
     * @return a new {@link GenericUDAFMaxRowEvaluator}
     * @throws SemanticException if no argument is given, or the first argument is not comparable
     *         (a map or a complex type containing a map)
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length == 0) {
            // Without this guard the lookup below fails with ArrayIndexOutOfBoundsException.
            throw new UDFArgumentTypeException(0, "At least one argument is expected.");
        }
        ObjectInspector oi =
                TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        if (!ObjectInspectorUtils.compareSupported(oi)) {
            throw new UDFArgumentTypeException(0,
                "Cannot support comparison of map<> type or complex type containing map<>.");
        }
        return new GenericUDAFMaxRowEvaluator();
    }

    /**
     * Evaluator that keeps a standard-object copy of the row with the largest first column.
     * Partial results are structs {@code <col0, col1, ...>}.
     */
    @UDFType(distinctLike = true)
    public static class GenericUDAFMaxRowEvaluator extends GenericUDAFEvaluator {

        /** OI of the incoming partial struct; non-null only in PARTIAL2/FINAL mode. */
        StructObjectInspector inputStructOI;
        /** OIs of the incoming columns: the original arguments, or the partial struct's fields. */
        ObjectInspector[] inputOIs;
        /** Standard OIs describing the copies held in the aggregation buffer. */
        ObjectInspector[] outputOIs;

        /**
         * Initializes the evaluator for the given mode.
         *
         * @param mode aggregation mode; PARTIAL1/COMPLETE consume original rows, while
         *        PARTIAL2/FINAL consume partial structs
         * @param parameters OIs of the original arguments, or a single struct OI for partials
         * @return the OI of both the partial and the final result (a standard struct)
         * @throws HiveException if a partial-consuming mode does not receive a single struct
         */
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            super.init(mode, parameters);

            // Decide by mode rather than argument shape. A single struct-typed original argument,
            // e.g. maxrow(named_struct(...)), must still be treated as raw input.
            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
                return initMapSide(parameters);
            }
            if (parameters.length != 1 || !(parameters[0] instanceof StructObjectInspector)) {
                throw new HiveException(
                    "Expected a single struct partial result in mode " + mode);
            }
            return initReduceSide((StructObjectInspector) parameters[0]);
        }

        /** Sets up OIs for consuming original rows; the output struct fields are col0..colN-1. */
        private ObjectInspector initMapSide(ObjectInspector[] parameters) throws HiveException {
            final int length = parameters.length;
            this.inputStructOI = null; // raw rows arrive as Object[]; no struct OI needed
            this.inputOIs = parameters;
            this.outputOIs = new ObjectInspector[length];

            final List<String> fieldNames = new ArrayList<String>(length);
            for (int i = 0; i < length; i++) {
                fieldNames.add("col" + i);
                outputOIs[i] = ObjectInspectorUtils.getStandardObjectInspector(parameters[i]);
            }

            final List<ObjectInspector> fieldOIs = Arrays.asList(outputOIs);
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        /** Sets up OIs for consuming partial structs produced by {@link #terminatePartial}. */
        private ObjectInspector initReduceSide(StructObjectInspector inputStructOI)
                throws HiveException {
            final List<? extends StructField> fields = inputStructOI.getAllStructFieldRefs();
            final int length = fields.size();
            this.inputStructOI = inputStructOI;
            this.inputOIs = new ObjectInspector[length];
            this.outputOIs = new ObjectInspector[length];

            for (int i = 0; i < length; i++) {
                ObjectInspector oi = fields.get(i).getFieldObjectInspector();
                inputOIs[i] = oi;
                outputOIs[i] = ObjectInspectorUtils.getStandardObjectInspector(oi);
            }

            return ObjectInspectorUtils.getStandardObjectInspector(inputStructOI);
        }

        /** Aggregation buffer holding the current maximum row, or {@code null} if none yet. */
        static class MaxAgg extends AbstractAggregationBuffer {
            Object[] objects;

            MaxAgg() {
                super();
            }

            void reset() {
                this.objects = null;
            }
        }

        @Override
        public MaxAgg getNewAggregationBuffer() throws HiveException {
            MaxAgg maxagg = new MaxAgg();
            maxagg.reset();
            return maxagg;
        }

        @Override
        public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            ((MaxAgg) agg).reset();
        }

        /** Consumes one original row; the handling is identical to merging a partial. */
        @Override
        public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
                Object[] parameters) throws HiveException {
            merge(agg, parameters);
        }

        /** Returns the current maximum row, or {@code null} when nothing was aggregated. */
        @Override
        public List<Object> terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            return terminate(agg);
        }

        /**
         * Merges a row (original {@code Object[]} or partial struct) into the buffer. The row
         * replaces the buffered one only when its first column compares strictly greater.
         *
         * @param partial the row to merge; {@code null} is ignored
         * @throws HiveException if {@code partial} has an unsupported representation
         */
        @Override
        public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }

            final MaxAgg maxagg = (MaxAgg) agg;
            final List<Object> otherObjects = asFieldList(partial);

            final boolean isMax = maxagg.objects == null
                    || ObjectInspectorUtils.compare(maxagg.objects[0], outputOIs[0],
                        otherObjects.get(0), inputOIs[0]) < 0;
            if (!isMax) {
                return;
            }

            // Copy to standard objects: Hive may reuse the incoming (possibly lazy) objects.
            final int length = otherObjects.size();
            final Object[] copies = new Object[length];
            for (int i = 0; i < length; i++) {
                copies[i] = ObjectInspectorUtils.copyToStandardObject(otherObjects.get(i),
                    inputOIs[i]);
            }
            maxagg.objects = copies;
        }

        /** Views {@code partial} as a list of column values for the current mode. */
        private List<Object> asFieldList(Object partial) throws HiveException {
            if (partial instanceof Object[]) {
                return Arrays.asList((Object[]) partial);
            } else if (partial instanceof LazyBinaryStruct) {
                return ((LazyBinaryStruct) partial).getFieldsAsList();
            } else if (inputStructOI != null) {
                return inputStructOI.getStructFieldsDataAsList(partial);
            } else {
                throw new HiveException("Invalid type: " + partial.getClass().getName());
            }
        }

        /**
         * Returns the maximum row as a list of standard objects, or {@code null} (SQL NULL) when no
         * row was aggregated.
         */
        @Override
        public List<Object> terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
                throws HiveException {
            final MaxAgg maxagg = (MaxAgg) agg;
            if (maxagg.objects == null) {
                // Empty group: Arrays.asList((Object[]) null) would throw NullPointerException.
                return null;
            }
            return Arrays.asList(maxagg.objects);
        }
    }
}
